import logging
import math

from collections import defaultdict
from lab.reports import Table, CellFormatter
from downward.reports import PlanningReport

def toString(f):
    """Return *f* rendered as a fixed-point string with two decimals (3 -> '3.00')."""
    return format(f, '.2f')

def valid(attribute):
    """Tell whether *attribute* carries a real value.

    The report uses -1 as the default for missing attributes, and None may
    also appear; both count as invalid.
    """
    differs_from_sentinel = attribute != -1
    return differs_from_sentinel and attribute is not None

class WVCReport(PlanningReport):
    """
    Report the ``c_star``, ``p_star`` and ``wvc`` attributes of one algorithm.

    For every domain a table is produced with one row per task (named
    ``domain:problem``) and the columns ``C*``, ``p*`` and ``WVC``; missing
    values are shown as -1. Runs whose ``error`` attribute is ``True`` are
    skipped. In addition, per-domain p* statistics (min, mean, max, standard
    deviation, count) are printed to stdout as LaTeX table rows, followed by
    the per-domain mean p* values.

    If the experiment contains more than one algorithm, use
    ``filter_algorithm='my_algorithm'`` to select exactly one algorithm
    for the report.

    >>> from downward.experiment import FastDownwardExperiment
    >>> exp = FastDownwardExperiment()
    >>> exp.add_report(WVCReport(
    ...     attributes=["p_star", "c_star","wvc","error"],
    ...     filter_algorithm=["computeWVC"]))

    """

    def __init__(self, **kwargs):
        PlanningReport.__init__(self, **kwargs)

    def _get_table(self, domain):
        """Build the result table for *domain* and record its p* statistics.

        Side effect: when the domain has at least one non-negative p* value,
        stores ``(min, mean, max, std, count)`` of those values in
        ``self.pstar[domain]``.
        """
        table = Table(title=domain, colored=True)
        pstar = []

        for problem in self.domains[domain]:
            run = self.runs[domain, problem, self.algo]
            # Only an explicit True marks a broken run; the attribute may be absent.
            if run.get('error', False) == True:
                print('Run: "' + problem + '" contains errors.')
                continue
            p_star = run.get('p_star', -1)
            pstar.append(p_star)
            row = domain + ':' + problem
            table.add_cell(row, 'C*', run.get('c_star', -1))
            table.add_cell(row, 'p*', p_star)
            table.add_cell(row, 'WVC', run.get('wvc', -1))

        # Missing values (-1 default, or None) must not enter the statistics;
        # the explicit None check keeps this from relying on Py2 None ordering.
        pstar = sorted(x for x in pstar if x is not None and x >= 0)
        if not pstar:
            return table
        count = len(pstar)
        # float() prevents truncating integer division under Python 2.
        avg = float(sum(pstar)) / count
        std = math.sqrt(sum((x - avg) ** 2 for x in pstar) / count)
        self.pstar[domain] = (pstar[0], avg, pstar[-1], std, count)
        return table

    def _print_pstar(self):
        """Print one LaTeX table row of p* statistics per domain.

        Columns: domain & min & mean & max & std & count. Domains without
        recorded statistics are printed with -1 in every column.
        """
        for domain in sorted(self.domains):
            low, mean, high, std, count = self.pstar.get(domain, (-1,) * 5)
            cells = [domain]
            cells.extend(toString(value) for value in (low, mean, high, std))
            cells.append(str(count))
            print(' & '.join(cells) + ' \\\\')

    def _print_avg_pstar(self):
        """Print the mean p* of each domain (sorted by name), one per line."""
        for domain in sorted(self.domains):
            mean = self.pstar.get(domain, (-1,) * 5)[1]
            print(toString(mean))

    def _print_pstar_data(self):
        """Print the raw p* values as a whitespace-separated matrix.

        One line per domain (sorted by name), one column per problem in the
        domain's problem order. Missing values and padding for shorter rows
        are printed as 'nan'.
        """
        rows = [
            [self.runs[domain, problem, self.algo].get('p_star', 'nan')
             for problem in self.domains[domain]]
            for domain in sorted(self.domains)
        ]
        width = max(len(row) for row in rows) if rows else 0
        for row in rows:
            padded = row + ['nan'] * (width - len(row))
            print(' '.join(str(value) for value in padded))

    def _check_algorithms(self):
        """Set ``self.algo`` to the report's only algorithm.

        Raises:
            ValueError: if the report does not contain exactly one algorithm.
        """
        if len(self.algorithms) != 1:
            msg = 'WVC Reports need exactly one algorithm.'
            logging.critical(msg)
            # stdlib logging.critical does not abort; without this raise the
            # missing self.algo would surface later as an AttributeError.
            raise ValueError(msg)
        self.algo = self.algorithms[0]

    def get_markup(self):
        """Return the markup of all per-domain tables, printing p* summaries."""
        self._check_algorithms()

        # Filled by _get_table, read by the _print_* helpers.
        self.pstar = {}
        tables = [self._get_table(domain) for domain in sorted(self.domains)]
        self._print_pstar()
        self._print_avg_pstar()
        # Enable to additionally dump the raw p* matrix:
        # self._print_pstar_data()
        return '\n'.join(str(table) for table in tables)


# Properties provided by the PlanningReport base class (name : key, value --
# or element type for sets):
# domains               : (domain), problems
# problems              : set (domain, problem)
# problem_runs          : (domain, problem), runs
# domain_algorithm_runs : (domain, algorithm), runs
# runs                  : (domain, problem, algo), run
# attributes
# algorithms            : set (algorithm)
# algorithm_info
